import json
import os
from itertools import combinations

import numpy as np
import pandas as pd

from config import PATH

class SetEncoder(json.JSONEncoder):
    """JSON encoder that writes ``set`` objects as JSON arrays."""

    def default(self, obj):
        """Return a list for a set; defer any other type to the base encoder."""
        if not isinstance(obj, set):
            # The base implementation raises TypeError for unsupported types.
            return json.JSONEncoder.default(self, obj)
        return list(obj)

def _decoder(obj):
    """Turn a list into a set; hand back any other value unchanged."""
    return set(obj) if isinstance(obj, list) else obj
        

def compute_ipc4_combinations(df, N=100, features=None,
                              cache_path="datasets/tmp/ipc2apn.json",
                              load_cache=False):
    """Compute statistics for every pair of IPC4 codes sharing at least N patents.

    Args:
        df: DataFrame with the columns ``appln_nr`` and ``ipc4`` plus one
            column per feature. Each feature is read once per application
            (first non-null value within the ``appln_nr`` group).
        N: Minimum number of distinct applications a pair of codes must
            share to be reported.
        features: Names of the feature columns to total. Defaults to every
            column of ``df`` other than ``appln_nr`` and ``ipc4``.
        cache_path: JSON file the code -> applications map is written to
            (and read from when ``load_cache`` is true). Pass ``None`` to
            skip writing the file.
        load_cache: When true, read the code -> applications map from
            ``cache_path`` instead of deriving it from ``df``. The file must
            have been produced from the same data and threshold.

    Returns:
        DataFrame with one row per qualifying pair and the columns ``ipc1``,
        ``ipc2``, one summed column per feature, ``total`` (number of shared
        applications) and ``share_<feature>`` (feature sum / total). Pairs
        are listed once, in sorted order with ``ipc1 < ipc2``. The frame is
        empty, but still carries all columns, when no pair qualifies.

    Raises:
        ValueError: If ``load_cache`` is true and ``cache_path`` is ``None``,
            or the loaded cache lacks a code required for ``df`` and ``N``.
    """
    if features is None:
        features = [col for col in df.columns if col not in ("appln_nr", "ipc4")]
    else:
        features = list(features)

    print("Computing maps...")
    # One Series per feature, indexed by appln_nr.
    by_application = df.groupby("appln_nr")
    fmap = {var: by_application[var].agg("first") for var in features}

    # Base set of codes: a pair can only share N applications if each code
    # has at least N distinct applications on its own.
    patents_per_code = df.groupby("ipc4")["appln_nr"].nunique()
    # Sorted so the output order and the ipc1/ipc2 orientation are reproducible.
    ipc_codes = sorted(patents_per_code.index[patents_per_code >= N])

    print("Computing IPC sets...")
    if load_cache:
        if cache_path is None:
            raise ValueError("load_cache=True requires a cache_path")
        ipc2apn = _load_ipc_sets(cache_path)
        missing = [code for code in ipc_codes if code not in ipc2apn]
        if missing:
            raise ValueError(
                "cache file {} lacks {} required IPC codes (e.g. {}); "
                "rebuild it with load_cache=False".format(
                    cache_path, len(missing), missing[0]))
    else:
        ipc2apn = _build_ipc_sets(df, ipc_codes)
        if cache_path is not None:
            _save_ipc_sets(ipc2apn, cache_path)

    print("Iterating over IPC combinations...")
    tot = len(ipc_codes) * (len(ipc_codes) - 1) // 2
    rows = []
    # combinations() yields each unordered pair exactly once.
    for i, (ipc1, ipc2) in enumerate(combinations(ipc_codes, 2), start=1):
        intersection = ipc2apn[ipc1] & ipc2apn[ipc2]
        if len(intersection) < N:
            continue

        # pandas >= 2.0 rejects a set as an indexer, so look up with a list.
        members = list(intersection)
        # skipna=False keeps the behaviour of the builtin sum(): a missing
        # feature value makes the total NaN instead of being ignored.
        sums = [fmap[var].loc[members].sum(skipna=False) for var in features]

        print("{}/{}: Found {}-{}".format(i, tot, ipc1, ipc2))
        rows.append([ipc1, ipc2] + sums + [len(members)])

    cols = ["ipc1", "ipc2"] + features + ["total"]
    # Passing the columns here keeps an empty result well-formed.
    newdf = pd.DataFrame(rows, columns=cols)

    for var in features:
        newdf["share_{}".format(var)] = newdf[var] / newdf["total"]

    return newdf


def _build_ipc_sets(df, ipc_codes):
    """Return {ipc4 code: set of appln_nr} for the given codes in one pass."""
    selected = df.loc[df["ipc4"].isin(ipc_codes), ["appln_nr", "ipc4"]]
    tot = len(ipc_codes)
    ipc2apn = {}
    grouped = selected.groupby("ipc4")["appln_nr"]
    for i, (ipc, apns) in enumerate(grouped, start=1):
        print("{}/{}: {}".format(i, tot, ipc))
        # Missing application numbers cannot be looked up later; drop them.
        ipc2apn[ipc] = set(apns.dropna())
    return ipc2apn


def _save_ipc_sets(ipc2apn, cache_path):
    """Write the code -> applications map to cache_path as JSON."""
    parent = os.path.dirname(cache_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(cache_path, "w") as f:
        json.dump(ipc2apn, f, cls=SetEncoder)
    print("done.")


def _load_ipc_sets(cache_path):
    """Read the code -> applications map from cache_path, restoring sets."""
    print("Loading ipc2apn from file...")
    with open(cache_path, "r") as f:
        raw = json.load(f)
    print("building index...")
    ipc2apn = {code: set(apns) for code, apns in raw.items()}
    print("done")
    return ipc2apn

if __name__ == "__main__":
    # Script entry point: read the application/IPC4 table found under PATH,
    # compute the pair statistics and write them next to the input file.
    input_csv = "{}/appln_ipc4.csv".format(PATH)
    output_csv = "{}/ipc4_pairs.csv".format(PATH)

    # Application numbers are read as strings rather than parsed as numbers.
    applications = pd.read_csv(input_csv, dtype={'appln_nr': str})
    pair_stats = compute_ipc4_combinations(applications)
    pair_stats.to_csv(output_csv, index=False)
